#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2024/3/6 16:34
# @Author  : wanghaoran
# @File    : utils.py
import numpy as np


def calculate_distance(v1, v2, distance):
    """Compute a distance/similarity score between two vectors, rounded to 4 places.

    Args:
        v1: First vector; anything ``np.asarray`` accepts (list, tuple, ndarray).
        v2: Second vector, compatible in shape with ``v1``.
        distance: Metric name. ``'L2'`` gives the Euclidean distance
            ``||v1 - v2||``; ``'IP'`` gives the inner product ``v1 . v2``.

    Returns:
        The metric value rounded to 4 decimal places (via ``round``).

    Raises:
        ValueError: If ``distance`` is not ``'L2'`` or ``'IP'``.
    """
    # asarray avoids copying inputs that are already ndarrays; the values are
    # only read, never mutated, so the results match np.array exactly.
    a = np.asarray(v1)
    b = np.asarray(v2)
    if distance == 'L2':
        return round(np.linalg.norm(a - b), 4)
    if distance == 'IP':
        return round(np.dot(a, b), 4)
    # Falling through used to return None implicitly, which hid typos in the
    # metric name from callers; fail loudly instead.
    raise ValueError(
        f"Unsupported distance metric: {distance!r} (expected 'L2' or 'IP')"
    )
